"""
polar : polar coordinate
angle: the orientation of the original image
"""

import argparse
from collections import defaultdict

import torch
from torch import optim
from torch.utils.data import DataLoader
import wandb
from tqdm import trange

from disvae.models.anneal import get_anneal
from disvae.models.losses import get_loss_s
from disvae.models.vae import VAE, init_specific_model
from disvae.models.encoders import get_encoder
from disvae.models.decoders import get_decoder

from utils.visualize import *
from exps.patterns import data_generator
import numpy as np
from exps.translation import train, get_preds
from utils.helpers import set_seed


def compute_threshold(action, orientation):
    """Train a beta-VAE on ``action`` and log its latent projection to wandb.

    Parameters
    ----------
    action : sequence of (image, factor) pairs used as the training dataset.
    orientation : int
        Recorded in the wandb run config as ``orientation``; not used in
        training itself.

    Returns
    -------
    dict
        The ``storer`` returned by ``train`` — accumulated losses, including
        per-dimension ``kl_*`` entries used below to rank latent dimensions.
    """
    params = dict(
        batch_size=20,
        epochs=1000,
        dimension=10,
        loss='betaH',
        orientation=orientation,
    )
    wandb.init(project="exps", config=params, group='rotation_thresholds', reinit=True)
    config = wandb.config
    # generate data

    epochs = config.epochs
    dim = config.dimension

    dl = DataLoader(action, pin_memory=True,
                    batch_size=config.batch_size, shuffle=True)

    vae = init_specific_model("Higgins", (1, 64, 64), dim, 5)
    vae.cuda()
    vae.train()

    # wandb.watch(vae, log_freq=10)
    # Capacity is annealed monotonically over the whole training run.
    iterations = len(dl) * epochs
    anneal = get_anneal('monotonic', iterations, 120, 1)

    loss_f = get_loss_s("betaH", anneal, 1, len(dl.dataset))

    storer = train(dl, vae, loss_f, epochs, 500, "cuda")
    vae.eval()

    # Rank latent dimensions by their accumulated KL divergence; the most
    # informative dimension (index[0]) is the one projected and plotted.
    kl_loss = torch.Tensor([v for k, v in storer.items() if k.startswith('kl_')])
    _, index = kl_loss.sort(descending=True)

    preds, targets = get_preds(vae, DataLoader(action, batch_size=100))
    data = preds.cpu()[:, index[0]]
    fig = plt.figure()
    plt.plot(data, 'x')
    wandb.log({'projection': wandb.Image(fig)})

    # Release GPU memory and close the wandb run before the next experiment.
    vae.cpu()
    wandb.join()
    return storer


# Load the base pattern and build its polar-translation dataset.
img = torch.load('patterns/0.pat')
dataset = data_generator.PolarTranslation(img)

# Two 1-D sweeps over the 40x40 factor grid: vary x with y fixed at 20,
# then vary y with x fixed at 20 (identical to the former literal tensors).
_sweep = torch.arange(40)
_fixed = torch.full((40,), 20, dtype=torch.long)
factors = [(_sweep, _fixed), (_fixed, _sweep)]

for s in range(5):  # repeat each sweep 5 times
    for i, factor in enumerate(factors):
        x, y = factor
        # One image per grid point along the sweep; add a channel dimension.
        imgs = dataset.o_imgs[x, y].unsqueeze(1).clone()
        action = list(zip(imgs, torch.stack((x, y), 1)))
        storer = compute_threshold(action, i + 10)
        # if storer['KL_loss'] > 0.1:
        #     break
